"""Knowledge transfer experiments.
Train MARL agents on a source dataset, save the models, then adapt them to a
target dataset and evaluate initial (warm-start) performance vs. a
from-scratch baseline.
"""
import os
import json
import shutil
import tempfile
from marl.train import marl_training, load_dataset
from marl.agents.student import StudentAgent
from marl.agents.teacher import TeacherAgent
from marl.environments.pipeline_env import PipelineEnvironment
from marl.environments.ml_components import COMPONENT_MAP
from experiments.utils import load_dataset_safely, seed_everything


def train_source(dataset: str, episodes: int, model_dir: str):
    """Run MARL training on the source dataset and report its statistics.

    Args:
        dataset: Name of the dataset, forwarded to ``marl_training`` as
            ``dataset_name``.
        episodes: Number of training episodes, forwarded to ``marl_training``.
        model_dir: Accepted for interface symmetry with ``adapt_and_eval``;
            this function does not read it.

    Returns:
        Whatever ``get_pipeline_statistics()`` returns on the environment
        handed back by ``marl_training``.
    """
    # Per the original author's note, marl_training persists the student and
    # teacher models itself, so nothing is saved explicitly here.
    trained_env = marl_training(dataset_name=dataset, episodes=episodes)
    stats = trained_env.get_pipeline_statistics()
    return stats


def adapt_and_eval(source_dataset: str, target_dataset: str, episodes: int, model_dir: str, seed: int = 42):
    """Warm-start agents from source-dataset checkpoints and run them on the target.

    Builds a ``PipelineEnvironment`` for ``target_dataset``, creates a student
    and a teacher agent, tries to load the checkpoints saved for
    ``source_dataset``, then runs ``episodes`` student/teacher rollouts while
    both agents keep learning.

    Args:
        source_dataset: Dataset name used to locate the saved checkpoints
            (``student_model_marl_<name>.pt`` / ``teacher_model_marl_<name>.pt``).
        target_dataset: Dataset name passed to ``load_dataset_safely``.
        episodes: Number of rollouts to run on the target environment.
        model_dir: Currently unused; checkpoints are always looked up under
            the literal ``models`` directory.
            NOTE(review): presumably this should match wherever
            ``marl_training`` writes its models -- confirm before wiring it in.
        seed: Value passed to ``seed_everything`` (defaults to the previously
            hard-coded 42).

    Returns:
        A dict with ``"best_perf"`` (the highest ``info['performance']`` seen
        at the end of any episode, starting from 0) merged with the entries
        of ``env.get_pipeline_statistics()``.

    Raises:
        RuntimeError: If ``load_dataset_safely`` returns no data; the message
            is the one it reported.
    """
    # Local import keeps this block self-contained; warnings go to stderr and
    # therefore do not pollute the JSON the script prints on stdout.
    import warnings

    seed_everything(seed)
    target_data, msg = load_dataset_safely(target_dataset)
    if target_data is None:
        raise RuntimeError(msg)
    env = PipelineEnvironment(
        target_data,
        available_components=list(COMPONENT_MAP.keys()),
        max_pipeline_length=8,
        debug=False,
    )
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Stabilize state dimension (avoid oscillations triggering rebuilds).
    # Best-effort: a failure here is reported but does not abort the run.
    try:
        env.fixed_state_dim = len(env._get_state_representation())
    except Exception as exc:
        warnings.warn(
            f"Could not pin the environment state dimension: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )

    # Derive the teacher state dimension from an actual teacher state rather
    # than from a heuristic such as state_dim + action_dim.
    env.reset()
    teacher_state_dim = len(env.get_teacher_state([]))

    student_path = os.path.join('models', f'student_model_marl_{source_dataset}.pt')
    teacher_path = os.path.join('models', f'teacher_model_marl_{source_dataset}.pt')

    student = StudentAgent(state_dim, action_dim, {'epsilon': 0.2, 'epsilon_min': 0.05})
    teacher = TeacherAgent(teacher_state_dim, action_dim, {'epsilon': 0.2, 'epsilon_min': 0.05})

    # Loading stays best-effort (the run continues with fresh weights), but a
    # missing or unloadable checkpoint means this is no longer a transfer run,
    # so say so instead of failing silently.
    for role, agent, path in (
        ('student', student, student_path),
        ('teacher', teacher, teacher_path),
    ):
        if not os.path.exists(path):
            warnings.warn(
                f"No {role} checkpoint at {path}; {role} starts untrained.",
                RuntimeWarning,
                stacklevel=2,
            )
            continue
        try:
            agent.load(path)
        except Exception as exc:
            warnings.warn(
                f"Failed to load {role} checkpoint {path}: {exc!r}; "
                f"{role} starts untrained.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Run limited episodes to measure warm-start performance.
    best_perf = 0
    for _ in range(episodes):
        state = env.reset()
        done = False
        # Reset per episode: if the loop below exits before the first step,
        # there is no step info, and the previous episode's must not leak in.
        info = {}
        student_actions = []
        while not done:
            valid = env.get_filtered_actions()
            if not valid:
                break
            student_action = student.act(state, valid, env=env)
            teacher_state = env.get_teacher_state(student_actions)
            # Keep the teacher network in sync with the observed state size.
            if teacher_state.shape[0] != teacher_state_dim:
                teacher.rebuild_networks(teacher_state.shape[0])
                teacher_state_dim = teacher_state.shape[0]
            should_intervene, teacher_action = teacher.act(teacher_state, valid, student_action, env=env)
            final_action, _source = env.process_teacher_intervention(
                student_action, should_intervene, teacher_action
            )
            next_state, reward, done, info = env.step(final_action)
            student.learn(state, final_action, reward, next_state, done)
            # NOTE(review): the teacher is given the student's next_state, not
            # a next *teacher* state -- confirm this matches TeacherAgent.learn.
            teacher.learn(teacher_state, (should_intervene, teacher_action), reward, next_state, done)
            state = next_state
            student_actions.append(student_action)
        best_perf = max(best_perf, info.get('performance', 0))
    return {"best_perf": best_perf, **env.get_pipeline_statistics()}


def run_transfer(source: str, target: str, source_episodes: int = 200, target_episodes: int = 50):
    """Train on ``source``, adapt to ``target``, and bundle both result sets.

    Args:
        source: Source dataset name, passed to ``train_source`` and
            ``adapt_and_eval``.
        target: Target dataset name, passed to ``adapt_and_eval``.
        source_episodes: Episode count for the source training phase.
        target_episodes: Episode count for the target adaptation phase.

    Returns:
        A dict with the keys ``source_dataset``, ``target_dataset``,
        ``source_stats`` and ``transfer_stats``.
    """
    model_dir = 'models'

    print(f"Training on source dataset: {source}")
    source_stats = train_source(source, source_episodes, model_dir)

    print(f"Adapting to target dataset: {target}")
    transfer_stats = adapt_and_eval(source, target, target_episodes, model_dir)

    summary = {
        "source_dataset": source,
        "target_dataset": target,
        "source_stats": source_stats,
        "transfer_stats": transfer_stats,
    }
    return summary

if __name__ == '__main__':
    # Command-line entry point: parse the dataset names and episode counts,
    # run the transfer experiment, and print the result as indented JSON.
    import argparse

    parser = argparse.ArgumentParser()
    cli_options = (
        ('--source', {'default': 'iris'}),
        ('--target', {'default': 'adult'}),
        ('--source_episodes', {'type': int, 'default': 200}),
        ('--target_episodes', {'type': int, 'default': 40}),
    )
    for flag, settings in cli_options:
        parser.add_argument(flag, **settings)
    cli = parser.parse_args()

    outcome = run_transfer(
        cli.source,
        cli.target,
        cli.source_episodes,
        cli.target_episodes,
    )
    print(json.dumps(outcome, indent=2))
